{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import struct\n",
    "\n",
    "\n",
    "class MnistData:\n",
    "    def __init__(self, train_image_path, train_label_path,\n",
    "                 test_image_path, test_label_path):\n",
    "        # 训练集和验证集的文件路径\n",
    "        self.train_image_path = train_image_path\n",
    "        self.train_label_path = train_label_path\n",
    "        self.test_image_path = test_image_path\n",
    "        self.test_label_path = test_label_path\n",
    "\n",
    "        # 获取训练集和验证集数据\n",
    "        # get_data()方法，参数为0获取训练集数据，参数为1获取验证集\n",
    "        self.train_images, self.train_labels = self.get_data(0)\n",
    "        self.test_images, self.test_labels = self.get_data(1)\n",
    "\n",
    "    def get_data(self, data_type):\n",
    "        if data_type == 0:  # 获取训练集数据\n",
    "            image_path = self.train_image_path\n",
    "            label_path = self.train_label_path\n",
    "        else:  # 获取验证集数据\n",
    "            image_path = self.test_image_path\n",
    "            label_path = self.test_label_path\n",
    "\n",
    "        with open(image_path, 'rb') as file1:\n",
    "            image_file = file1.read()\n",
    "        with open(label_path, 'rb') as file2:\n",
    "            label_file = file2.read()\n",
    "\n",
    "        label_index = 0\n",
    "        image_index = 0\n",
    "        labels = []\n",
    "        images = []\n",
    "\n",
    "        # 读取训练集图像数据文件的文件信息\n",
    "        magic, num_of_datasets, rows, columns = struct.unpack_from('>IIII',\n",
    "                                                                   image_file, image_index)\n",
    "        image_index += struct.calcsize('>IIII')\n",
    "\n",
    "        for i in range(num_of_datasets):\n",
    "            # 读取784个unsigned byte，即一副图像的所有像素值\n",
    "            temp = struct.unpack_from('>784B', image_file, image_index)\n",
    "            # 将读取的像素数据转换成28*28的矩阵\n",
    "            temp = np.reshape(temp, (28, 28))\n",
    "            # 归一化处理\n",
    "            temp = temp / 255\n",
    "            images.append(temp)\n",
    "            image_index += struct.calcsize('>784B')  # 每次增加784B\n",
    "\n",
    "        # 跳过描述信息\n",
    "        label_index += struct.calcsize('>II')\n",
    "        labels = struct.unpack_from('>' + str(num_of_datasets) + 'B', label_file, label_index)\n",
    "\n",
    "        # one-hot\n",
    "        labels = np.eye(10)[np.array(labels)]\n",
    "\n",
    "        return np.array(images), labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "train_image_path = './data/train-images-idx3-ubyte'\n",
    "train_label_path = './data/train-labels-idx1-ubyte'\n",
    "test_image_path = './data/t10k-images-idx3-ubyte'\n",
    "test_label_path = './data/t10k-labels-idx1-ubyte'\n",
    "\n",
    "data = MnistData(train_image_path, train_label_path,\n",
    "                 test_image_path, test_label_path)\n",
    "\n",
    "print(type(data.train_images))\n",
    "print(type(data.train_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "[[0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.01176471 0.07058824 0.07058824 0.07058824 0.49411765 0.53333333\n",
      "  0.68627451 0.10196078 0.65098039 1.         0.96862745 0.49803922\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.11764706 0.14117647 0.36862745 0.60392157\n",
      "  0.66666667 0.99215686 0.99215686 0.99215686 0.99215686 0.99215686\n",
      "  0.88235294 0.6745098  0.99215686 0.94901961 0.76470588 0.25098039\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.19215686 0.93333333 0.99215686 0.99215686 0.99215686\n",
      "  0.99215686 0.99215686 0.99215686 0.99215686 0.99215686 0.98431373\n",
      "  0.36470588 0.32156863 0.32156863 0.21960784 0.15294118 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.07058824 0.85882353 0.99215686 0.99215686 0.99215686\n",
      "  0.99215686 0.99215686 0.77647059 0.71372549 0.96862745 0.94509804\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.31372549 0.61176471 0.41960784 0.99215686\n",
      "  0.99215686 0.80392157 0.04313725 0.         0.16862745 0.60392157\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.05490196 0.00392157 0.60392157\n",
      "  0.99215686 0.35294118 0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.54509804\n",
      "  0.99215686 0.74509804 0.00784314 0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.04313725\n",
      "  0.74509804 0.99215686 0.2745098  0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.1372549  0.94509804 0.88235294 0.62745098 0.42352941 0.00392157\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.31764706 0.94117647 0.99215686 0.99215686 0.46666667\n",
      "  0.09803922 0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.17647059 0.72941176 0.99215686 0.99215686\n",
      "  0.58823529 0.10588235 0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.0627451  0.36470588 0.98823529\n",
      "  0.99215686 0.73333333 0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.97647059\n",
      "  0.99215686 0.97647059 0.25098039 0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.18039216 0.50980392 0.71764706 0.99215686\n",
      "  0.99215686 0.81176471 0.00784314 0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.15294118 0.58039216 0.89803922 0.99215686 0.99215686 0.99215686\n",
      "  0.98039216 0.71372549 0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.09411765 0.44705882\n",
      "  0.86666667 0.99215686 0.99215686 0.99215686 0.99215686 0.78823529\n",
      "  0.30588235 0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.09019608 0.25882353 0.83529412 0.99215686\n",
      "  0.99215686 0.99215686 0.99215686 0.77647059 0.31764706 0.00784314\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.07058824 0.67058824 0.85882353 0.99215686 0.99215686 0.99215686\n",
      "  0.99215686 0.76470588 0.31372549 0.03529412 0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.21568627 0.6745098\n",
      "  0.88627451 0.99215686 0.99215686 0.99215686 0.99215686 0.95686275\n",
      "  0.52156863 0.04313725 0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.53333333 0.99215686\n",
      "  0.99215686 0.99215686 0.83137255 0.52941176 0.51764706 0.0627451\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "print(type(data.train_images))\n",
    "print(type(data.train_labels))\n",
    "\n",
    "print(type(data.train_images[0]))\n",
    "print(type(data.train_labels[0]))\n",
    "\n",
    "print(data.train_images[0])\n",
    "print(data.train_labels[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_3\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten_3 (Flatten)          (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense_6 (Dense)              (None, 512)               401920    \n",
      "_________________________________________________________________\n",
      "dropout_3 (Dropout)          (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dense_7 (Dense)              (None, 10)                5130      \n",
      "=================================================================\n",
      "Total params: 407,050\n",
      "Trainable params: 407,050\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "model = tf.keras.models.Sequential([\n",
    "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "  tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
    "  tf.keras.layers.Dropout(0.2),\n",
    "  tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0612 05:24:07.434529 139725216646912 tf_logging.py:161] <tensorflow.python.keras.layers.recurrent.UnifiedLSTM object at 0x7f13b8412438>: Note that this layer is not optimized for performance. Please use tf.keras.layers.CuDNNLSTM for better performance on GPU.\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Input 0 of layer unified_lstm_6 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 784]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-8fc208a9e39f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m   \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLSTM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m784\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactivation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m   \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m   \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactivation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m ])\n\u001b[1;32m      9\u001b[0m model.compile(optimizer='adam',\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py\u001b[0m in \u001b[0;36m_method_wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    454\u001b[0m     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setattr_tracking\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    455\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 456\u001b[0;31m       \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    457\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    458\u001b[0m       \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setattr_tracking\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprevious_value\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, layers, name)\u001b[0m\n\u001b[1;32m    106\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    107\u001b[0m       \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 108\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    110\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py\u001b[0m in \u001b[0;36m_method_wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    454\u001b[0m     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setattr_tracking\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    455\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 456\u001b[0;31m       \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    457\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    458\u001b[0m       \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setattr_tracking\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprevious_value\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py\u001b[0m in \u001b[0;36madd\u001b[0;34m(self, layer)\u001b[0m\n\u001b[1;32m    185\u001b[0m       \u001b[0;31m# If the model is being built continuously on top of an input layer:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m       \u001b[0;31m# refresh its output.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m       \u001b[0moutput_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         raise TypeError('All layers in a Sequential model '\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, initial_state, constants, **kwargs)\u001b[0m\n\u001b[1;32m    656\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    657\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0minitial_state\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mconstants\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 658\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRNN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    659\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    660\u001b[0m     \u001b[0;31m# If any of `initial_state` or `constants` are specified and are Keras\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m    587\u001b[0m         \u001b[0;31m# the corresponding TF subgraph inside `backend.get_graph()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    588\u001b[0m         input_spec.assert_input_compatibility(self.input_spec, inputs,\n\u001b[0;32m--> 589\u001b[0;31m                                               self.name)\n\u001b[0m\u001b[1;32m    590\u001b[0m         \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    591\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/apython/lib/python3.6/site-packages/tensorflow/python/keras/engine/input_spec.py\u001b[0m in \u001b[0;36massert_input_compatibility\u001b[0;34m(input_spec, inputs, layer_name)\u001b[0m\n\u001b[1;32m    122\u001b[0m                          \u001b[0;34m'expected ndim='\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m', found ndim='\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    123\u001b[0m                          \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'. Full shape received: '\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m                          str(x.shape.as_list()))\n\u001b[0m\u001b[1;32m    125\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_ndim\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m       \u001b[0mndim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndims\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Input 0 of layer unified_lstm_6 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 784]"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "model = tf.keras.models.Sequential([\n",
    "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "  tf.keras.layers.LSTM(128, input_shape=(784, 1), activation=tf.nn.relu),\n",
    "  tf.keras.layers.Dropout(0.2),\n",
    "  tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 6s 97us/sample - loss: 0.2224 - accuracy: 0.9337\n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 6s 96us/sample - loss: 0.0961 - accuracy: 0.9705\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 6s 94us/sample - loss: 0.0691 - accuracy: 0.9776\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 5s 92us/sample - loss: 0.0523 - accuracy: 0.9836\n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 5s 91us/sample - loss: 0.0416 - accuracy: 0.9865\n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 5s 91us/sample - loss: 0.0369 - accuracy: 0.9878\n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 5s 91us/sample - loss: 0.0308 - accuracy: 0.9895\n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 5s 92us/sample - loss: 0.0257 - accuracy: 0.9911\n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 6s 94us/sample - loss: 0.0255 - accuracy: 0.9915\n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 6s 96us/sample - loss: 0.0220 - accuracy: 0.9926\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f19b49d72e8>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(data.train_images, data.train_labels, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000/10000 [==============================] - 1s 66us/sample - loss: 0.0731 - accuracy: 0.9817\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.07307186364395202, 0.9817]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(data.test_images, data.test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 6s 97us/sample - loss: 0.2205 - accuracy: 0.9349\n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 6s 94us/sample - loss: 0.0982 - accuracy: 0.9704\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 6s 92us/sample - loss: 0.0694 - accuracy: 0.9782\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 6s 98us/sample - loss: 0.0531 - accuracy: 0.9827\n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 6s 93us/sample - loss: 0.0437 - accuracy: 0.9855\n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 5s 89us/sample - loss: 0.0362 - accuracy: 0.9883\n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 5s 91us/sample - loss: 0.0291 - accuracy: 0.9897\n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 5s 90us/sample - loss: 0.0273 - accuracy: 0.9907\n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 5s 90us/sample - loss: 0.0241 - accuracy: 0.9919\n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 5s 90us/sample - loss: 0.0223 - accuracy: 0.9921\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f192f04b898>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "model = tf.keras.models.Sequential([\n",
    "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "  tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
    "  tf.keras.layers.Dropout(0.2),\n",
    "  tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model.fit(data.train_images, data.train_labels, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "5\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "(x_train, y_train),(x_test, y_test) = mnist.load_data()\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "\n",
    "print(type(x_train))\n",
    "print(type(x_test))\n",
    "\n",
    "# print(x_train[0])\n",
    "print(y_train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}